use std::mem;
use std::mem::size_of;
use windows::Win32::Foundation::HANDLE;
use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW};
use windows::core::HSTRING;

use crate::classes::*;
use crate::enums::*;
use crate::{ProcessInfo, ProcessMitigations};

/// Minimal read-only view over a 32-bit flag word.
///
/// Masks are produced with [`BitVector32::create_mask`] (the lowest bit) and then
/// [`BitVector32::next_mask`] (the bit just above a given mask).
struct BitVector32 {
    /// The raw flag word being inspected.
    bits: i32,
}

impl BitVector32 {
    /// Wraps the raw flag word `data`.
    fn new(data: i32) -> Self {
        BitVector32 { bits: data }
    }

    /// Returns the mask that selects the lowest bit (`0x1`).
    fn create_mask() -> i32 {
        0b1
    }

    /// Returns the mask one bit above `previous` (`previous` shifted left by one).
    fn next_mask(previous: i32) -> i32 {
        previous.wrapping_shl(1)
    }

    /// Reports whether any bit selected by `mask` is set in the wrapped word.
    fn get(&self, mask: i32) -> bool {
        let selected = self.bits & mask;
        selected != 0
    }
}

pub struct GetProcessMitigationCommand;

impl GetProcessMitigationCommand {
    pub fn get_process_mitigation_policy(
        handle: HANDLE,
        policy: ProcessMitigationPolicy,
        buffer: &mut [u8],
    ) -> Result<(), windows::core::Error> {
        unsafe {
            let kernel32 = LoadLibraryW(&HSTRING::from("kernel32.dll"))?;
            let proc_addr =
                GetProcAddress(kernel32, windows::core::s!("GetProcessMitigationPolicy"));
            if let Some(proc_addr) = proc_addr {
                let func: extern "system" fn(HANDLE, u32, *mut u8, u32) -> i32 =
                    mem::transmute(proc_addr);
                let result = func(
                    handle,
                    policy as u32,
                    buffer.as_mut_ptr(),
                    buffer.len() as u32,
                );
                if result != 0 {
                    Ok(())
                } else {
                    Err(windows::core::Error::from_win32())
                }
            } else {
                Err(windows::core::Error::from_win32())
            }
        }
    }

    pub fn export_mitigation(xml_path: &str) -> Result<(), String> {
        unsafe {
            let dll = LoadLibraryW(&HSTRING::from("MitigationConfiguration.dll"))
                .map_err(|_| "Failed to load MitigationConfiguration.dll")?;
            let proc_addr = GetProcAddress(dll, windows::core::s!("ExportMitigation"));
            if let Some(proc_addr) = proc_addr {
                let func: extern "system" fn(*const u16) -> i32 = mem::transmute(proc_addr);
                let wide: Vec<u16> = xml_path.encode_utf16().chain(std::iter::once(0)).collect();
                let code = func(wide.as_ptr());
                if code >= 0 {
                    Ok(())
                } else {
                    Err(format!("Export failed with code: {}", code))
                }
            } else {
                Err("ExportMitigation function not found".to_string())
            }
        }
    }

    /// Reads all mitigation fields from a running process handle into a ProcessMitigations struct.
    pub fn get_policy_from_running_process(
        process: &ProcessInfo,
    ) -> Result<ProcessMitigations, String> {
        // Initialize with defaults
        let mut from_running_process = ProcessMitigations::new(&process.process_name);
        from_running_process.source = "Running Process".to_string();
        from_running_process.id = process.id;

        unsafe {
            // OpenProcess for query
            let process_handle = windows::Win32::System::Threading::OpenProcess(
                windows::Win32::System::Threading::PROCESS_QUERY_INFORMATION,
                false,
                process.id,
            )
            .map_err(|e| format!("Failed to open process: {:?}", e))?;

            //
            // DEP Policy (8 bytes)
            //
            let mut buf_dep = [0u8; 8];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessDEPPolicy,
                &mut buf_dep,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes([buf_dep[0], buf_dep[1], buf_dep[2], buf_dep[3]]);
                let bits = BitVector32::new(val);
                let mask1 = BitVector32::create_mask();
                let mask2 = BitVector32::next_mask(mask1);

                if let Some(dep_policy) = from_running_process.policies.get_mut(&PolicyName::DEP) {
                    dep_policy.set_policy(
                        PolicyOptionName::Enable,
                        Policy::bool_to_option_value(bits.get(mask1)),
                    );
                    dep_policy.set_policy(
                        PolicyOptionName::EmulateAtlThunks,
                        Policy::bool_to_option_value(bits.get(mask2)),
                    );
                }
            } else {
                // default ON for 64-bit runtime
                if let Some(dep_policy) = from_running_process.policies.get_mut(&PolicyName::DEP) {
                    let is_x64 = size_of::<usize>() == 8;
                    dep_policy.set_policy(
                        PolicyOptionName::Enable,
                        Policy::bool_to_option_value(is_x64),
                    );
                }
            }

            //
            // ASLR Policy (4 bytes)
            //
            let mut buf_aslr = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessASLRPolicy,
                &mut buf_aslr,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_aslr);
                let bits = BitVector32::new(val);
                let mask3 = BitVector32::create_mask();
                let mask4 = BitVector32::next_mask(mask3);
                let mask5 = BitVector32::next_mask(mask4);
                let mask6 = BitVector32::next_mask(mask5);

                if let Some(aslr) = from_running_process.policies.get_mut(&PolicyName::ASLR) {
                    aslr.set_policy(
                        PolicyOptionName::BottomUp,
                        Policy::bool_to_option_value(bits.get(mask3)),
                    );
                    aslr.set_policy(
                        PolicyOptionName::ForceRelocateImages,
                        Policy::bool_to_option_value(bits.get(mask4)),
                    );
                    aslr.set_policy(
                        PolicyOptionName::HighEntropy,
                        Policy::bool_to_option_value(bits.get(mask5)),
                    );
                    aslr.set_policy(
                        PolicyOptionName::RequireInfo,
                        Policy::bool_to_option_value(bits.get(mask6)),
                    );
                }
            }

            //
            // StrictHandle Policy (4 bytes)
            //
            let mut buf_strict = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessStrictHandleCheckPolicy,
                &mut buf_strict,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_strict);
                let bits = BitVector32::new(val);
                let m = BitVector32::create_mask();
                if let Some(strict) = from_running_process
                    .policies
                    .get_mut(&PolicyName::StrictHandle)
                {
                    strict.set_policy(
                        PolicyOptionName::Enable,
                        Policy::bool_to_option_value(bits.get(m)),
                    );
                }
            }

            //
            // SystemCall Policy (4 bytes)
            //
            let mut buf_sys = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessSystemCallDisablePolicy,
                &mut buf_sys,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_sys);
                let bits = BitVector32::new(val);
                let m7 = BitVector32::create_mask();
                let m8 = BitVector32::next_mask(m7);
                let m9 = BitVector32::next_mask(m8);
                let m10 = BitVector32::next_mask(m9);

                if let Some(sys) = from_running_process
                    .policies
                    .get_mut(&PolicyName::SystemCalls)
                {
                    sys.set_policy(
                        PolicyOptionName::DisableWin32kSystemCalls,
                        Policy::bool_to_option_value(bits.get(m7)),
                    );
                    sys.set_policy(
                        PolicyOptionName::Audit,
                        Policy::bool_to_option_value(bits.get(m8)),
                    );
                    sys.set_policy(
                        PolicyOptionName::DisableFsctlSystemCalls,
                        Policy::bool_to_option_value(bits.get(m9)),
                    );
                    sys.set_policy(
                        PolicyOptionName::AuditFsctlSystemCalls,
                        Policy::bool_to_option_value(bits.get(m10)),
                    );
                }
            }

            //
            // ExtensionPoint Policy (4 bytes)
            //
            let mut buf_ext = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessExtensionPointDisablePolicy,
                &mut buf_ext,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_ext);
                let bits = BitVector32::new(val);
                let m = BitVector32::create_mask();
                if let Some(ep) = from_running_process
                    .policies
                    .get_mut(&PolicyName::ExtensionPoints)
                {
                    ep.set_policy(
                        PolicyOptionName::DisableExtensionPoints,
                        Policy::bool_to_option_value(bits.get(m)),
                    );
                }
            }

            //
            // DynamicCode Policy (4 bytes)
            //
            let mut buf_dyn = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessDynamicCodePolicy,
                &mut buf_dyn,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_dyn);
                let bits = BitVector32::new(val);
                let m11 = BitVector32::create_mask();
                let m12 = BitVector32::next_mask(m11);
                let m13 = BitVector32::next_mask(m12);

                if let Some(dc) = from_running_process
                    .policies
                    .get_mut(&PolicyName::DynamicCode)
                {
                    dc.set_policy(
                        PolicyOptionName::BlockDynamicCode,
                        Policy::bool_to_option_value(bits.get(m11)),
                    );
                    dc.set_policy(
                        PolicyOptionName::AllowThreadsToOptOut,
                        Policy::bool_to_option_value(bits.get(m12)),
                    );
                    dc.set_policy(
                        PolicyOptionName::Audit,
                        Policy::bool_to_option_value(bits.get(m13)),
                    );
                }
            }

            //
            // ControlFlowGuard Policy (4 bytes)
            //
            let mut buf_cfg = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessControlFlowGuardPolicy,
                &mut buf_cfg,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_cfg);
                let bits = BitVector32::new(val);
                let m14 = BitVector32::create_mask();
                let m15 = BitVector32::next_mask(m14);
                let m16 = BitVector32::next_mask(m15);

                if let Some(cfg) = from_running_process
                    .policies
                    .get_mut(&PolicyName::ControlFlowGuard)
                {
                    cfg.set_policy(
                        PolicyOptionName::Enable,
                        Policy::bool_to_option_value(bits.get(m14)),
                    );
                    cfg.set_policy(
                        PolicyOptionName::SuppressExports,
                        Policy::bool_to_option_value(bits.get(m15)),
                    );
                    cfg.set_policy(
                        PolicyOptionName::StrictControlFlowGuard,
                        Policy::bool_to_option_value(bits.get(m16)),
                    );
                }
            }

            //
            // SignedBinaries Policy (4 bytes)
            //
            let mut buf_sig = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessSignaturePolicy,
                &mut buf_sig,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_sig);
                let bits = BitVector32::new(val);
                let m17 = BitVector32::create_mask();
                let m18 = BitVector32::next_mask(m17);
                let m19 = BitVector32::next_mask(m18);
                let m20 = BitVector32::next_mask(m19);

                if let Some(sig) = from_running_process
                    .policies
                    .get_mut(&PolicyName::SignedBinaries)
                {
                    sig.set_policy(
                        PolicyOptionName::MicrosoftSignedOnly,
                        Policy::bool_to_option_value(bits.get(m17)),
                    );
                    sig.set_policy(
                        PolicyOptionName::AllowStoreSignedBinaries,
                        Policy::bool_to_option_value(bits.get(m18)),
                    );
                    sig.set_policy(
                        PolicyOptionName::Audit,
                        Policy::bool_to_option_value(bits.get(m19)),
                    );
                    sig.set_policy(
                        PolicyOptionName::AuditStoreSigned,
                        Policy::bool_to_option_value(bits.get(m20)),
                    );
                }
            }

            //
            // FontDisable Policy (4 bytes)
            //
            let mut buf_font = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessFontDisablePolicy,
                &mut buf_font,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_font);
                let bits = BitVector32::new(val);
                let m21 = BitVector32::create_mask();
                let m22 = BitVector32::next_mask(m21);

                if let Some(font) = from_running_process.policies.get_mut(&PolicyName::Fonts) {
                    font.set_policy(
                        PolicyOptionName::DisableNonSystemFonts,
                        Policy::bool_to_option_value(bits.get(m21)),
                    );
                    font.set_policy(
                        PolicyOptionName::Audit,
                        Policy::bool_to_option_value(bits.get(m22)),
                    );
                }
            }

            //
            // ImageLoad Policy (4 bytes)
            //
            let mut buf_img = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessImageLoadPolicy,
                &mut buf_img,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_img);
                let bits = BitVector32::new(val);
                let m23 = BitVector32::create_mask();
                let m24 = BitVector32::next_mask(m23);
                let m25 = BitVector32::next_mask(m24);
                let m26 = BitVector32::next_mask(m25);
                let m27 = BitVector32::next_mask(m26);
                let m28 = BitVector32::next_mask(m27);

                if let Some(img) = from_running_process
                    .policies
                    .get_mut(&PolicyName::ImageLoad)
                {
                    img.set_policy(
                        PolicyOptionName::BlockRemoteImageLoads,
                        Policy::bool_to_option_value(bits.get(m23)),
                    );
                    img.set_policy(
                        PolicyOptionName::AuditRemoteImageLoads,
                        Policy::bool_to_option_value(bits.get(m26)),
                    );
                    img.set_policy(
                        PolicyOptionName::BlockLowLabelImageLoads,
                        Policy::bool_to_option_value(bits.get(m24)),
                    );
                    img.set_policy(
                        PolicyOptionName::AuditLowLabelImageLoads,
                        Policy::bool_to_option_value(bits.get(m27)),
                    );
                    img.set_policy(
                        PolicyOptionName::PreferSystem32,
                        Policy::bool_to_option_value(bits.get(m25)),
                    );
                    img.set_policy(
                        PolicyOptionName::AuditPreferSystem32,
                        Policy::bool_to_option_value(bits.get(m28)),
                    );
                }
            }

            //
            // Payload Policy (4 bytes)
            //
            let mut buf_pay = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessPayloadRestrictionPolicy,
                &mut buf_pay,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_pay);
                let bits = BitVector32::new(val);
                let m29 = BitVector32::create_mask();
                let m30 = BitVector32::next_mask(m29);
                let m31 = BitVector32::next_mask(m30);
                let m32 = BitVector32::next_mask(m31);
                let m33 = BitVector32::next_mask(m32);
                let m34 = BitVector32::next_mask(m33);
                let m35 = BitVector32::next_mask(m34);
                let m36 = BitVector32::next_mask(m35);
                let m37 = BitVector32::next_mask(m36);
                let m38 = BitVector32::next_mask(m37);
                let m39 = BitVector32::next_mask(m38);
                let m40 = BitVector32::next_mask(m39);

                if let Some(pay) = from_running_process.policies.get_mut(&PolicyName::Payload) {
                    pay.set_policy(
                        PolicyOptionName::EnableExportAddressFilter,
                        Policy::bool_to_option_value(bits.get(m29)),
                    );
                    pay.set_policy(
                        PolicyOptionName::AuditEnableExportAddressFilter,
                        Policy::bool_to_option_value(bits.get(m30)),
                    );
                    pay.set_policy(
                        PolicyOptionName::EnableExportAddressFilterPlus,
                        Policy::bool_to_option_value(bits.get(m31)),
                    );
                    pay.set_policy(
                        PolicyOptionName::AuditEnableExportAddressFilterPlus,
                        Policy::bool_to_option_value(bits.get(m32)),
                    );
                    pay.set_policy(
                        PolicyOptionName::EnableImportAddressFilter,
                        Policy::bool_to_option_value(bits.get(m33)),
                    );
                    pay.set_policy(
                        PolicyOptionName::AuditEnableImportAddressFilter,
                        Policy::bool_to_option_value(bits.get(m34)),
                    );
                    pay.set_policy(
                        PolicyOptionName::EnableRopStackPivot,
                        Policy::bool_to_option_value(bits.get(m35)),
                    );
                    pay.set_policy(
                        PolicyOptionName::AuditEnableRopStackPivot,
                        Policy::bool_to_option_value(bits.get(m36)),
                    );
                    pay.set_policy(
                        PolicyOptionName::EnableRopCallerCheck,
                        Policy::bool_to_option_value(bits.get(m37)),
                    );
                    pay.set_policy(
                        PolicyOptionName::AuditEnableRopCallerCheck,
                        Policy::bool_to_option_value(bits.get(m38)),
                    );
                    pay.set_policy(
                        PolicyOptionName::EnableRopSimExec,
                        Policy::bool_to_option_value(bits.get(m39)),
                    );
                    pay.set_policy(
                        PolicyOptionName::AuditEnableRopSimExec,
                        Policy::bool_to_option_value(bits.get(m40)),
                    );
                }
            }

            //
            // SEHOP Policy (4 bytes)
            //
            let mut buf_seh = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessSEHOPPolicy,
                &mut buf_seh,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_seh);
                let bits = BitVector32::new(val);
                let m = BitVector32::create_mask();

                if let Some(seh) = from_running_process.policies.get_mut(&PolicyName::SEHOP) {
                    seh.set_policy(
                        PolicyOptionName::Enable,
                        Policy::bool_to_option_value(bits.get(m)),
                    );
                }
            }

            //
            // ChildProcess Policy (4 bytes)
            //
            let mut buf_child = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessChildProcessPolicy,
                &mut buf_child,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_child);
                let bits = BitVector32::new(val);
                let m41 = BitVector32::create_mask();
                let m42 = BitVector32::next_mask(m41);

                if let Some(child) = from_running_process
                    .policies
                    .get_mut(&PolicyName::ChildProcess)
                {
                    child.set_policy(
                        PolicyOptionName::DisallowChildProcessCreation,
                        Policy::bool_to_option_value(bits.get(m41)),
                    );
                    child.set_policy(
                        PolicyOptionName::Audit,
                        Policy::bool_to_option_value(bits.get(m42)),
                    );
                }
            }

            //
            // UserShadowStack Policy (4 bytes)
            //
            let mut buf_uss = [0u8; 4];
            if Self::get_process_mitigation_policy(
                process_handle,
                ProcessMitigationPolicy::ProcessUserShadowStackPolicy,
                &mut buf_uss,
            )
            .is_ok()
            {
                let val = i32::from_le_bytes(buf_uss);
                let bits = BitVector32::new(val);
                let m43 = BitVector32::create_mask();
                let m44 = BitVector32::next_mask(m43);
                let m45 = BitVector32::next_mask(m44);
                let m46 = BitVector32::next_mask(m45);
                let m47 = BitVector32::next_mask(m46);
                let m48 = BitVector32::next_mask(m47);
                let m49 = BitVector32::next_mask(m48);
                let m50 = BitVector32::next_mask(m49);

                if let Some(uss) = from_running_process
                    .policies
                    .get_mut(&PolicyName::UserShadowStack)
                {
                    uss.set_policy(
                        PolicyOptionName::UserShadowStack,
                        Policy::bool_to_option_value(bits.get(m43)),
                    );
                    uss.set_policy(
                        PolicyOptionName::UserShadowStackStrictMode,
                        Policy::bool_to_option_value(bits.get(m47)),
                    );
                    uss.set_policy(
                        PolicyOptionName::AuditUserShadowStack,
                        Policy::bool_to_option_value(bits.get(m44)),
                    );
                    uss.set_policy(
                        PolicyOptionName::SetContextIpValidation,
                        Policy::bool_to_option_value(bits.get(m45)),
                    );
                    uss.set_policy(
                        PolicyOptionName::AuditSetContextIpValidation,
                        Policy::bool_to_option_value(bits.get(m46)),
                    );
                    uss.set_policy(
                        PolicyOptionName::BlockNonCetBinaries,
                        Policy::bool_to_option_value(bits.get(m48)),
                    );
                    uss.set_policy(
                        PolicyOptionName::BlockNonCetBinariesNonEhcont,
                        Policy::bool_to_option_value(bits.get(m49)),
                    );
                    uss.set_policy(
                        PolicyOptionName::AuditBlockNonCetBinaries,
                        Policy::bool_to_option_value(bits.get(m50)),
                    );
                }
            }
        }

        Ok(from_running_process)
    }
}
